
The learnable weights and quantization parameters in the n-th module are updated by minimizing the reconstruction errors. The proposed MREM can be optimized in parallel: given the previously trained modules, only the weights and quantization parameters of the current module are updated. Moreover, the number of modules N can be adjusted to the memory constraints of the available computing resources. This flexibility in the number of transformer layers per module ensures that a proper trade-off between layer-wise correlation and the memory overhead of the training devices can be achieved. Although a similar block-wise objective was previously proposed in [137], it requires computing second-order Hessian matrices for optimization, which can be computationally prohibitive for large language models.
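As a rough illustration, the per-module objective might look like the following PyTorch-style sketch. The names quant_module and fp_module, and the use of a plain MSE reconstruction loss, are assumptions for illustration rather than the authors' exact formulation.

```python
import torch
import torch.nn.functional as F

def module_reconstruction_loss(quant_module, fp_module, inputs):
    """Reconstruction error of one partitioned module: the quantized module's
    output is pulled towards the full-precision module's output on the same
    inputs. Only the current module's weights and quantization parameters
    receive gradients; previously trained modules stay frozen."""
    with torch.no_grad():
        target = fp_module(inputs)     # full-precision reference output
    output = quant_module(inputs)      # output with quantized weights/activations
    return F.mse_loss(output, target)

# Hypothetical usage for the n-th module:
# optimizer = torch.optim.Adam(quant_module.parameters())
# loss = module_reconstruction_loss(quant_module, fp_module, cached_inputs)
# loss.backward(); optimizer.step()
```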

5.5.2 Model Parallel Strategy

Second, a new model parallel strategy is designed to accelerate the training process of MREM. A common strategy is to optimize the modules one by one; however, this sequential strategy still requires a long training time. Motivated by this, the authors propose a model parallel strategy that allows all modules to be trained jointly, without synchronization between adjacent modules, by allocating each partitioned module to an individual computing device.

Specifically, every module is computed one after another during the first $t_0$ steps to construct an input queue $\mathcal{I}$, which contains $t_0$ intermediate outputs. For the n-th module, its input queue comes from the previous module, i.e., $\mathcal{I}_{n-1} = \{f_{n-1}^{1}, f_{n-1}^{2}, f_{n-1}^{3}, \ldots, f_{n-1}^{t_0}\}$.

Then, parallel training takes place: each module samples its input from its corresponding input queue and optimizes the loss defined by Eq. (5.10). Meanwhile, the input queues are updated with a first-in-first-out rule throughout training: once a module produces its output, the result is fed into the input queue of the following module. In the backward pass, the gradients propagate locally within each module, without affecting its predecessors. As a result, such a design avoids the load-imbalance issue caused by straggler modules, bringing nearly the theoretical N× speed-up when deployed on N GPUs. Such results are superior to previous data parallel [131] or model parallel [96] techniques.
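The queue mechanics can be sketched as follows. The FIFO queues use Python's collections.deque; the function names, the warm-up loop, and the use of an MSE loss as a stand-in for Eq. (5.10) are illustrative assumptions, not the authors' implementation (which would additionally place each module on its own device).

```python
from collections import deque
import torch
import torch.nn.functional as F

def build_input_queues(modules, data_loader, t0):
    """Warm-up: run the modules sequentially for t0 steps so that the n-th
    module's queue holds t0 intermediate outputs of the (n-1)-th module."""
    queues = [deque(maxlen=t0) for _ in modules]
    for _, batch in zip(range(t0), data_loader):
        x = batch
        for n, module in enumerate(modules):
            queues[n].append(x)   # input queue of the n-th module
            x = module(x)         # becomes a candidate input for module n+1
    return queues

def parallel_step(n, quant_module, fp_module, queues, optimizer):
    """One local training step of the n-th module (meant to run on its own
    device). Gradients stay inside the module, so no synchronization with
    adjacent modules is required."""
    x = queues[n].popleft()                     # first-in-first-out sampling
    out = quant_module(x)
    if n + 1 < len(queues):
        queues[n + 1].append(out.detach())      # feed the successor's input queue
    with torch.no_grad():
        target = fp_module(x)                   # full-precision reference output
    loss = F.mse_loss(out, target)              # stands in for Eq. (5.10)
    optimizer.zero_grad()
    loss.backward()                             # backward pass stays local to this module
    optimizer.step()
    return loss.item()
```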

5.5.3 Annealed Teacher Forcing

Third, the authors design an annealed teacher forcing scheme for the parallel strategy. They find that naive parallel training suffers from the propagation of reconstruction error, since each quantized module passes its quantization error on to its successors before being fully optimized. In particular, all modules are optimized simultaneously rather than sequentially in the parallel strategy: the next module takes outputs from its input queue before its predecessor is fully optimized, so the predecessor's reconstruction error propagates to the following modules before it is sufficiently minimized. To solve this problem, the proposed annealed teacher forcing, which is similar to the method in [246], lets the full-precision module provide clean signals to the next quantized module. This breaks the propagation of the reconstruction error and further improves the performance of the parallel strategy. Specifically, the output $f_n$ of the n-th full-precision module serves as the clean input to the (n+1)-th quantized module, substituting the original $\hat{f}_n$ that comes from the quantized module. As a result, $f_n$ stops the propagation of accumulated error through the quantized modules.
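A minimal sketch of this substitution is given below. The convex combination with a coefficient lam and its annealing schedule are assumptions suggested by the name of the technique; the excerpt above only states that the clean full-precision output replaces the quantized one.

```python
def teacher_forced_input(fp_out, quant_out, lam):
    """Input fed to the (n+1)-th quantized module during parallel training.
    fp_out:    f_n, output of the n-th full-precision module (clean signal)
    quant_out: hat{f}_n, output of the n-th quantized module
    lam:       mixing weight; lam = 1.0 is pure teacher forcing.
    NOTE: the mixing and its annealing are an assumption, not a detail
    confirmed by the surrounding text."""
    return lam * fp_out + (1.0 - lam) * quant_out

# A hypothetical linear annealing of lam from 1 to 0 over training:
# lam = max(0.0, 1.0 - step / total_steps)
```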

Nevertheless, such an approach breaks the connection to the previous quantized modules and may suffer from a forward inconsistency between training and inference for the quantized model. To solve this problem, the actual input to the (n+1)-th quantized module is the